
from init import PATHS
import logging
LOGGER = logging.getLogger(__name__)
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.style.use('seaborn')
import seaborn as sns
sns.set(style="white")
pal = sns.color_palette('colorblind')
from C_PatentVariables import conventions_names_colors
dict_subsector_shortnames, dict_var_colors, list_of_subsectors_in_cleancars = conventions_names_colors.main()



def main():
    """Produce every OEM-level table and figure generated by this module.

    Writes the carmaker list table, then loads the OEM patent-family time
    series once and hands it to the two plots that take it as an argument,
    and finally runs the steps that load their own inputs.
    """
    LOGGER.info('Begin c_othergraphs_OEMs.py')
    table_with_list_OEMs()
    counts_csv = PATHS.dropbox / 'Data_outputted/C_PatentVariables/TimeSeries_FamilyCounts_in_ipc_cpc_transpo_from_OEM_and_Subsidiaries.csv'
    family_counts = pd.read_csv(counts_csv)
    for plot_with_counts in (OEMs_total_cleancars_vs_ICE, OEMs_total_cleancar_subsectors):
        plot_with_counts(family_counts)
    for standalone_step in (sumstats_sales, ev_sales, maingraph_wH2_OEMs_total_Bat_vs_FC):
        standalone_step()
    LOGGER.info('END c_othergraphs_OEMs.py')


def table_with_list_OEMs():
    """Write ``list_oems.tex``: each carmaker ID with its MarkLines and Orbis names.

    Reads ``PATHS.marklines / 'OEMs.csv'``, keeps the unique
    (ID, MarkLines name, Orbis name) triples sorted by ID then Orbis name,
    title-cases the Orbis names, and renders them as a LaTeX table whose
    three columns are all index levels.
    """
    header_for = {
        'OEM_Level1_ID': 'Carmaker ID',
        'Level1_MLName': 'Markline Name',
        'Level1_OrbisName': 'Orbis Name',
    }
    listing = (
        pd.read_csv(PATHS.marklines / 'OEMs.csv')[list(header_for)]
        .drop_duplicates()
        .sort_values(by=['OEM_Level1_ID', 'Level1_OrbisName'])
    )
    listing['Level1_OrbisName'] = listing['Level1_OrbisName'].str.title()
    listing = listing.rename(columns=header_for).set_index(list(header_for.values()))

    # Global display option: keeps long names from being truncated in the output.
    pd.set_option('display.max_colwidth', None)
    latex = listing.to_latex(index=True)
    # Drop the empty spacer row pandas emits for the index-only table (exact-text match).
    latex = latex.replace('   &               &                                     \\\\\n', '   ')
    # Centre the ID column.
    latex = latex.replace('{lll}', '{cll}')
    with open(PATHS.tables / 'list_oems.tex', "w") as handle:
        handle.write(latex)


def _n_countries_to_reach_share(country_shares, cutoff, keys):
    """Count, per carmaker, the top-ranked countries needed for a cumulative share >= ``cutoff``.

    ``country_shares`` must carry ``sharecumsum`` and ``sharerank`` computed in
    the same (descending-share) order within each carmaker.
    """
    # Rank of the first country at which the running share crosses the cutoff.
    threshold = (country_shares.loc[country_shares['sharecumsum'] >= cutoff]
                 .groupby(keys)['sharerank'].min()
                 .rename('threshold'))
    within = country_shares.merge(threshold, on=keys)
    within = within[within['sharerank'] <= within['threshold']]
    return within.groupby(keys)['Country'].count()


def _n_countries_in_year(firm_sales_by_country, year, keys):
    """Count, per carmaker, the non-null ``Country`` rows recorded in ``year``."""
    in_year = firm_sales_by_country[firm_sales_by_country['Year'] == year]
    return in_year.groupby(keys)['Country'].count().astype(int)


def sumstats_sales():
    """Write ``sum_stats_firms_sales.tex`` with carmaker-level sales summary statistics.

    Reads ``sales_firm_year_country_level.csv``, keeps rows with ``value > 0``,
    and computes per carmaker (``OEM_Level1_ID`` / ``Level1_MLName``):

    * mean annual sales (sum over countries per year, averaged over years);
    * geographic concentration: Herfindahl index of each country's share of
      the carmaker's summed per-country mean sales;
    * mean number of countries with sales rows per year;
    * number of top-ranked countries needed to cover 50% / 80% of sales;
    * number of countries with sales rows in 2004 and in 2018.

    Column names use ``'_'`` to split into a two-row LaTeX header.
    """
    keys = ['OEM_Level1_ID', 'Level1_MLName']
    firm_sales_by_country = pd.read_csv(PATHS.dropbox / 'Data_outputted/A_AutoIndustry/sales_firm_year_country_level.csv')
    firm_sales_by_country = firm_sales_by_country[firm_sales_by_country['value'] > 0]

    # Mean annual sales. Select 'value' before aggregating so non-numeric
    # columns such as 'Country' are not aggregated too.
    annual_sales = firm_sales_by_country.groupby(keys + ['Year'])['value'].sum()
    data = (annual_sales.groupby(keys).mean()
            .rename('Mean Annual_Sales')
            .sort_values(ascending=False)
            .astype(int)
            .reset_index())

    # Country shares, based on each country's mean sales over the years.
    df = firm_sales_by_country.groupby(keys + ['Country'])['value'].mean()
    total = df.groupby(keys).sum().rename('total')
    df = df.reset_index().merge(total, on=keys)
    df['share'] = df['value'] / df['total']
    df = df.sort_values(by=['OEM_Level1_ID', 'share'], ascending=False)
    df['sharecumsum'] = df.groupby(keys)['share'].cumsum()
    # 1-based position in the same order used for the cumulative sum. A plain
    # rank() averages ties (2.5 -> 2 after the int cast), which can disagree
    # with 'sharecumsum' and miscount the 50%/80% thresholds below.
    df['sharerank'] = df.groupby(keys).cumcount() + 1

    # Herfindahl index of country shares.
    df['share2'] = df['share'] * df['share']
    hhi = df.groupby(keys)['share2'].sum().rename('Geographic_Concentration')
    data = data.merge(hhi, on=keys, how='outer')

    # Mean number of countries per year.
    countries_per_year = firm_sales_by_country.groupby(keys + ['Year'])['Country'].count()
    mean_countries = countries_per_year.groupby(keys).mean().rename('Mean Number_of Countries')
    data = data.merge(mean_countries, on=keys, how='outer')

    # Countries needed to cover 50% and 80% of sales.
    for cutoff, label in ((.5, 'Mean Nbr Countries_with 50%'), (.8, 'Mean Nbr Countries_with 80%')):
        n_countries = _n_countries_to_reach_share(df, cutoff, keys).rename(label)
        data = data.merge(n_countries, on=keys, how='outer')

    # Number of countries in the first and last sample years.
    year_columns = []
    for year in (2004, 2018):
        label = f'Nbr Countries_in {year}'
        n_countries = _n_countries_in_year(firm_sales_by_country, year, keys).rename(label)
        data = data.merge(n_countries, on=keys, how='outer')
        year_columns.append(label)
    # Outer merges leave NaN (float) for carmakers absent that year; '{:.0f}'
    # renders those as 'nan', which is blanked out in the LaTeX below.
    for label in year_columns:
        data[label] = data[label].apply('{:.0f}'.format)

    # Formatting
    data = data.rename(columns={'OEM_Level1_ID': 'Carmaker ID', 'Level1_MLName': 'Name'})
    data['Name'] = data['Name'].replace({
        'SAIC (Shanghai Automotive Industry Corporation (Group))': 'SAIC',
        'Changan/Chana (Changan Automobile (Group))': 'Changan/Chana',
    })
    data['Mean Annual_Sales'] = data['Mean Annual_Sales'].apply('{:,}'.format)
    data.columns = pd.MultiIndex.from_tuples([tuple(c.split('_')) for c in data.columns])
    # multicolumn_format='c' centres the grouped headers ('Mean Nbr Countries',
    # 'Nbr Countries') directly; the default alignment varies across pandas
    # versions, so post-editing '{l}' could silently miss.
    # NOTE(review): column_format already contains braces, so the tabular
    # spec is presumably emitted as {{...}}; confirm the LaTeX compiles.
    df_tex = data.to_latex(index=False, column_format='{clccccccc}', multicolumn_format='c')
    # Blank out missing values.
    # NOTE(review): plain substring replace; a carmaker name containing
    # 'nan'/'NaN' would also be altered.
    df_tex = df_tex.replace('NaN', ' ')
    df_tex = df_tex.replace('nan', ' ')
    with open(PATHS.tables / 'sum_stats_firms_sales.tex', "w") as f:
        f.write(df_tex)



def OEMs_total_cleancars_vs_ICE(ts_counts):
    """Plot OEM clean-car vs ICE patent family counts for filing years 1990-2015.

    Draws clean-car, ICE, ICE-efficiency, and a dashed combined
    ICE + ICE-efficiency series. Saves PNG and PDF to
    ``PATHS.figures / 'oems'``.

    Args:
        ts_counts: time series with ``earliest_filing_year``,
            ``Count_CleanCar_excl``, ``Count_ICE_excl`` and ``Count_EffICE_excl``.
            The frame is not modified.
    """
    combined_var = 'Count_ICE_andEffICE_excl'
    # Local copy so the module-level colour/label map is not mutated.
    var_styles = dict(dict_var_colors)
    var_styles[combined_var] = ['ICE + ICE Efficiency (combined)', pal[1]]

    # Work on a filtered copy instead of adding a column to the caller's frame.
    in_window = ts_counts[ts_counts['earliest_filing_year'].isin(range(1990, 2016))].copy()
    in_window[combined_var] = in_window['Count_ICE_excl'] + in_window['Count_EffICE_excl']

    fig, ax = plt.subplots()
    for var in ['Count_CleanCar_excl', 'Count_ICE_excl', 'Count_EffICE_excl', combined_var]:
        ax.plot(in_window['earliest_filing_year'], in_window[var],
                label=var_styles[var][0], color=var_styles[var][1], linewidth=3,
                linestyle='--' if var == combined_var else '-')
    ax.legend(ncol=1)
    ax.set_xlabel('Year')
    ax.set_ylabel('Number of DocDB Families')
    for ext in ('png', 'pdf'):
        fig.savefig(PATHS.figures / 'oems' / f'patenting_total_OEM_cleancars_v_ICE.{ext}', bbox_inches='tight')
    plt.close(fig)




def OEMs_total_cleancar_subsectors(ts_counts):
    """Plot OEM patent family counts per clean-car subsector for filing years 1990-2015.

    Saves the figure with a linear y axis and again with a log y axis, each as
    PNG and PDF under ``PATHS.figures / 'oems'``.

    Args:
        ts_counts: time series with ``earliest_filing_year`` and one
            ``Count_*_excl`` column per subsector. The frame is not modified.
    """
    subsectors = ['Count_Bat_excl', 'Count_FC_excl', 'Count_HV_excl', 'Count_EV_excl',
                  'Count_H2_excl', 'Count_Stor_excl', 'Count_Biofuels_excl']
    in_window = ts_counts[ts_counts['earliest_filing_year'].isin(range(1990, 2016))]
    fig, ax = plt.subplots()
    for var in subsectors:
        ax.plot(in_window['earliest_filing_year'], in_window[var],
                label=dict_var_colors[var][0], color=dict_var_colors[var][1],
                linewidth=3, linestyle='-')
    ax.legend(ncol=1)
    ax.set_xlabel('Year')
    ax.set_ylabel('Number of DocDB Families')
    out_dir = PATHS.figures / 'oems'
    for ext in ('png', 'pdf'):
        fig.savefig(out_dir / f'patenting_total_OEM_cleancars_subsectors.{ext}', bbox_inches='tight')
    # Same figure again on a log scale, where the smaller subsectors are easier to read.
    ax.set_yscale('log')
    for ext in ('png', 'pdf'):
        fig.savefig(out_dir / f'patenting_total_OEM_cleancars_subsectors_logscale.{ext}', bbox_inches='tight')
    plt.close(fig)



def ev_sales():
    """Plot global yearly BEV and FCEV sales on a log scale.

    Reads ``sales_country_year_type_level.csv`` and sums the ``EV`` and ``FCV``
    columns over all rows of each ``Year``. Saves PNG and PDF to
    ``PATHS.figures / 'oems'``. Line colours match the battery / fuel-cell
    patent series.
    """
    sales_by_type = pd.read_csv(PATHS.dropbox / 'Data_outputted/A_AutoIndustry/sales_country_year_type_level.csv')
    yearly = sales_by_type.groupby('Year')[['EV', 'FCV']].sum().reset_index()
    # Draw on a fresh figure rather than whatever figure happens to be current.
    fig, ax = plt.subplots()
    for column, label, color_key in (('EV', 'BEV', 'Count_Bat_excl'), ('FCV', 'FCEV', 'Count_FC_excl')):
        ax.plot(yearly['Year'].values, yearly[column].values, label=label,
                color=dict_var_colors[color_key][1], linewidth=3)
    ax.set_yscale('log')
    ax.legend()
    ax.set_xlabel('Year')
    ax.set_ylabel('Number of cars sold globally')
    for ext in ('png', 'pdf'):
        fig.savefig(PATHS.figures / 'oems' / f'BEV_FCEV_global_sales.{ext}', bbox_inches='tight')
    plt.close(fig)



def maingraph_wH2_OEMs_total_Bat_vs_FC(ts_counts=None):
    """Plot OEM battery vs fuel-cell patent family counts (``_wH2`` variants), 1990-2015.

    Args:
        ts_counts: optional, already-loaded OEM family-count time series. When
            ``None`` (the default), the CSV is read from ``PATHS.dropbox``, which
            lets callers holding the frame avoid a second read.

    Saves PNG and PDF to ``PATHS.figures / 'oems'``.
    """
    if ts_counts is None:
        ts_counts = pd.read_csv(PATHS.dropbox / 'Data_outputted/C_PatentVariables/TimeSeries_FamilyCounts_in_ipc_cpc_transpo_from_OEM_and_Subsidiaries.csv')
    in_window = ts_counts[ts_counts['earliest_filing_year'].isin(range(1990, 2016))]
    fig, ax = plt.subplots()
    for var in ('Count_Bat_excl_wH2', 'Count_FC_excl_wH2'):
        ax.plot(in_window['earliest_filing_year'], in_window[var],
                label=dict_var_colors[var][0], color=dict_var_colors[var][1], linewidth=3)
    ax.legend(ncol=1)
    ax.set_xlabel('Year')
    ax.set_ylabel('Number of Patent Families')
    for ext in ('png', 'pdf'):
        fig.savefig(PATHS.figures / 'oems' / f'patenting_total_OEM_Bat_v_FC_wH2.{ext}', bbox_inches='tight')
    plt.close(fig)